# from utils.DataLoader import DataLoader
# import utils
# import os
# import torch
# import torchvision
# import random
# from torchvision import transforms as transforms
# import numpy as np
# import copy
# from torch.utils.data import ConcatDataset
#
# class DataLoader_fashion(DataLoader):
#     def __init__(self,
#                  split_num=200,
#                  pick_num=2,
#                  batch_size=100,
#                  input_require_shape=None,
#                  shuffle=True,
#                  pool_size=None,
#                  recreate=True,
#                  params=None,
#                  *args,
#                  **kwargs):
#         if params is not None:
#             if pool_size is not None:
#                 split_num = pool_size * params['N']
#                 pick_num = params['N']
#             else:
#                 split_num = int(params['C'] * params['N'])
#                 pick_num = params['N']
#             batch_size = params['batch_size']
#         if split_num % pick_num != 0:
#             raise RuntimeError('split_num must be divisible by the number of pick_num.')
#         pool_size = split_num // pick_num
#         name = 'Fashion_pool_' + str(pool_size) + 'split_' + str(split_num) + 'pick' + str(pick_num) + '_batchsize_' + str(batch_size) + '_sort_split_input_require_shape_' + str(input_require_shape)
#         nickname = 'fashion B' + str(batch_size) + ' S'+ str(split_num) + ' P' + str(pick_num) + ' N' + str(pool_size)
#         super().__init__(name, nickname, pool_size, batch_size, input_require_shape)
#
#         file_path = utils.pool_folder_path + name + '.npy'
#
#         if os.path.exists(file_path) and (recreate == False):
#             data_loader = np.load(file_path, allow_pickle=True).item()
#             for attr in list(data_loader.__dict__.keys()):
#                 setattr(self, attr, data_loader.__dict__[attr])
#             print('Successfully Read the Data Pool.')
#         else:
#             transform = transforms.Compose(
#                 [transforms.ToTensor()])
#             trainset = torchvision.datasets.FashionMNIST(root=utils.data_folder_path, train=True,
#                                                     download=True, transform=transform)
#             trainloader = torch.utils.data.DataLoader(trainset, batch_size=trainset.data.shape[0],
#                                                       shuffle=True, num_workers=1)
#             testset = torchvision.datasets.FashionMNIST(root=utils.data_folder_path, train=False,
#                                                    download=True, transform=transform)
#             testloader = torch.utils.data.DataLoader(testset, batch_size=testset.data.shape[0],
#                                                      shuffle=False, num_workers=1)
#             # global_training_data = torch.utils.data.DataLoader(copy.deepcopy(trainset),
#             #                                                    batch_size=self.batch_size,
#             #                                                    shuffle=True, num_workers=1)
#             # global_test_data = torch.utils.data.DataLoader(copy.deepcopy(testset),
#             #                                                batch_size=self.batch_size,
#             #                                                shuffle=False, num_workers=1)
#             # # modify
#             # num_samples = len(trainset)
#             # noise_ratio = 0.5
#             # num_samples_to_modify = int(num_samples * noise_ratio)
#             # indices_to_modify = np.random.choice(num_samples, num_samples_to_modify, replace=False)
#             # num_classes = 10
#             # for idx in indices_to_modify:
#             #     new_label = np.random.randint(0, num_classes)  # Generate random incorrect label
#             #     while new_label == trainset.targets[idx]:  # Ensure the new label is different from the original one
#             #         new_label = np.random.randint(0, num_classes)
#             #     trainset.targets[idx] = new_label
#             totalset = ConcatDataset([trainset, testset])
#             totalloader = torch.utils.data.DataLoader(totalset, batch_size=len(totalset),
#                                                       shuffle=True, num_workers=1)
#             for i, (input_data, targets) in enumerate(trainloader):
#                 train_input_data = input_data
#                 train_target_data = targets
#             for i, (input_data, targets) in enumerate(testloader):
#                 test_input_data = input_data
#                 test_target_data = targets
#             for i, (input_data, targets) in enumerate(totalloader):
#                 total_input_data = input_data
#                 total_target_data = targets
#
#             self.cal_data_shape(train_input_data.shape)
#
#             self.target_class_num = 10
#
#             self.global_training_data = []
#             self.global_test_data = []
#             # for (input_data, targets) in global_training_data:
#             #     self.global_training_data.append((input_data.reshape([-1] + self.input_data_shape), targets))
#             # for (input_data, targets) in global_test_data:
#             #     self.global_test_data.append((input_data.reshape([-1] + self.input_data_shape), targets))
#             self.total_training_number = len(trainset)
#             self.total_test_number = len(testset)
#             self.output_size = 10
#             self.model4data = 'mlp'
#             self.task_name = 'fashion_classification'
#
#             def create_data_pool(data_pool, input_data, target_data, key_name=None):
#                 order = torch.argsort(target_data)
#                 input_data = input_data[order, :]
#                 target_data = target_data[order]
#
#                 count = 0
#                 amount = input_data.shape[0] // split_num
#                 indices = list(range(input_data.shape[0]))
#                 split_data_indices_list = []
#                 for split_idx in range(split_num):
#                     start_idx = count
#                     end_idx = count + amount
#                     if end_idx > input_data.shape[0] - 1:
#                         end_idx = input_data.shape[0] - 1
#                     split_data_indices = indices[start_idx: end_idx]
#                     split_data_indices_list.append(split_data_indices)
#                     count += amount
#                 for pool_idx in range(pool_size):
#                     data_indices = []
#
#                     for i in range(pick_num):
#                         pick_data_indices = split_data_indices_list[random.randint(0, len(split_data_indices_list) - 1)]
#                         data_indices += pick_data_indices
#                         split_data_indices_list.remove(pick_data_indices)
#                     random.shuffle(data_indices)
#                     local_data_number = len(data_indices)
#                     # train_test_split
#                     train_data_indices = data_indices[:601]
#                     test_data_indices = data_indices[601:]
#                     local_train_data_number = len(train_data_indices)
#                     local_test_data_number = len(test_data_indices)
#
#                     train_batch_data_indices_list = DataLoader.separate_list(train_data_indices, self.batch_size)
#                     test_batch_data_indices_list = DataLoader.separate_list(test_data_indices, self.batch_size)
#
#
#
#                     # batch_data_indices_list = DataLoader.separate_list(data_indices, self.batch_size)
#                     # local_data = []
#                     # for batch_data_indices in batch_data_indices_list:
#                     #
#                     #     batch_input_data = input_data[batch_data_indices].reshape([-1] + self.input_data_shape).float()
#                     #     batch_target_data = target_data[batch_data_indices]
#                     #     local_data.append((batch_input_data, batch_target_data))
#
#                     local_train_data, local_test_data = [], []
#                     for batch_data_indices in train_batch_data_indices_list:
#
#                         batch_input_data = input_data[batch_data_indices].reshape([-1] + self.input_data_shape).float()
#                         batch_target_data = target_data[batch_data_indices]
#                         local_train_data.append((batch_input_data, batch_target_data))
#
#                     data_pool[pool_idx]['local_training_data'] = local_train_data
#                     data_pool[pool_idx]['local_training_number'] = local_train_data_number
#                     data_pool[pool_idx]['data_name'] = str(pool_idx)
#
#                     for batch_data_indices in test_batch_data_indices_list:
#                         batch_input_data = input_data[batch_data_indices].reshape([-1] + self.input_data_shape).float()
#                         batch_target_data = target_data[batch_data_indices]
#                         local_test_data.append((batch_input_data, batch_target_data))
#
#                     data_pool[pool_idx]['local_test_data'] = local_test_data
#                     data_pool[pool_idx]['local_test_number'] = local_test_data_number
#                     data_pool[pool_idx]['data_name'] = str(pool_idx)
#             data_pool = [{} for _ in range(self.pool_size)]
#             #
#             # create_data_pool(data_pool, train_input_data, train_target_data, 'local_training')
#             #
#             # create_data_pool(data_pool, test_input_data, test_target_data, 'local_test')
#             create_data_pool(data_pool, total_input_data, total_target_data)
#
#             self.data_pool = data_pool
#             np.save(file_path, self)
#     def allocate(self, client_list):
#
#         choose_data_pool_item_indices = np.random.choice(list(range(self.pool_size)), len(client_list), replace=False)
#         for idx, client in enumerate(client_list):
#             data_pool_item = self.data_pool[choose_data_pool_item_indices[idx]]
#             client.update_data(choose_data_pool_item_indices[idx],
#                                data_pool_item['local_training_data'],
#                                data_pool_item['local_training_number'],
#                                data_pool_item['local_test_data'],
#                                data_pool_item['local_test_number'])
# # import os, json
# # import gzip
# # import numpy as np
# #
# # NAME=[]
# # def load_mnist(path, kind='train'):
# #
# #
# #     """Load MNIST data from `path`"""
# #     labels_path = os.path.join(path,
# #                                '%s-labels-idx1-ubyte.gz'
# #                                % kind)
# #     images_path = os.path.join(path,
# #                                '%s-images-idx3-ubyte.gz'
# #                                % kind)
# #
# #     with gzip.open(labels_path, 'rb') as lbpath:
# #         labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
# #                                offset=8)
# #
# #     with gzip.open(images_path, 'rb') as imgpath:
# #         images = np.frombuffer(imgpath.read(), dtype=np.uint8,
# #                                offset=16).reshape(len(labels), 784)
# #
# #     return images, labels
# #
# #
# # def generate_dataset():
# #
# #   X_train, y_train = load_mnist('raw_data/fashion', kind='train')
# #   X_test, y_test = load_mnist('raw_data/fashion', kind='t10k')
# #
# #
# #   # some simple normalization
# #   mu = np.mean(X_train.astype(np.float32), 0)
# #   sigma = np.std(X_train.astype(np.float32), 0)
# #
# #   X_train = (X_train.astype(np.float32) - mu)/(sigma+0.001)
# #   X_test = (X_test.astype(np.float32) - mu)/(sigma+0.001)
# #
# #   return X_train.tolist(), y_train.tolist(), X_test.tolist(), y_test.tolist()
# #
# #
# # def main():
# #     train_output = "./train/mytrain.json"
# #     test_output = "./test/mytest.json"
# #
# #
# #     X_train, y_train, X_test, y_test = generate_dataset()
# #
# #
# #     # Create data structure
# #     train_data = {'users': [], 'user_data':{}, 'num_samples':[]}
# #     test_data = {'users': [], 'user_data':{}, 'num_samples':[]}
# #
# #
# #     # label 0: T-shirt(top); 2: pullover; 6: Shirt
# #     X_trains=[[] for i in range(10)]
# #     y_trains = [[] for i in range(10)]
# #     for idx, item in enumerate(X_train):
# #         i=y_train[idx]
# #         X_trains[i].append(X_train[idx])
# #         y_trains[i].append(y_train[idx])
# #
# #     X_tests = [[] for i in range(10)]
# #     y_tests = [[] for i in range(10)]
# #     for idx, item in enumerate(X_test):
# #         i=y_test[idx]
# #         X_tests[i].append(X_test[idx])
# #         y_tests[i].append(y_test[idx])
# #     label_dict={0:'T-shirt', 2:'pullover', 6:'shirt'}
# #     selected=[0,2,6]
# #     cvt_labels= {}
# #     for i in range(len(selected)):
# #         cvt_labels[selected[i]]=i
# #     for i in selected:
# #         train_len=len(X_trains[i])
# #         print("training set for {}: {}".format(i,train_len))
# #         test_len = len(X_tests[i])
# #         uname=label_dict[i]
# #         train_data['users'].append(uname)
# #         train_data['user_data'][uname] = {'x': X_trains[i], 'y': [cvt_labels[lb] for lb in y_trains[i]]}
# #         train_data['num_samples'].append(train_len)
# #         test_data['users'].append(uname)
# #         test_data['user_data'][uname] = {'x': X_tests[i], 'y': [cvt_labels[lb] for lb in y_tests[i]]}
# #         test_data['num_samples'].append(test_len)
# #
# #     with open(train_output,'w') as outfile:
# #         json.dump(train_data, outfile)
# #     with open(test_output, 'w') as outfile:
# #         json.dump(test_data, outfile)
# #
# #
# # if __name__ == "__main__":
# #     main()

from utils.DataLoader import DataLoader
import utils
import os
import torch
import torchvision
import random
from torchvision import transforms as transforms
import numpy as np
import copy
from torch.utils.data import ConcatDataset


class DataLoader_fashion(DataLoader):
    """Fashion-MNIST data pool for federated clients.

    The train and test splits are concatenated, sorted by label and cut into
    ``split_num`` equally sized contiguous shards, so each shard spans a
    narrow, contiguous range of labels.

    Each of the ``pool_size = split_num // pick_num`` pool items receives
    ``pick_num`` shards drawn at random without replacement. Their indices are
    shuffled and divided into a local training part and a local test part
    using the global train/(train + test) size ratio.

    The finished object is saved with ``np.save`` under
    ``utils.pool_folder_path`` and reloaded on later runs unless ``recreate``
    is true.
    """

    def __init__(self,
                 split_num=200,
                 pick_num=2,
                 batch_size=100,
                 input_require_shape=None,
                 shuffle=True,
                 pool_size=None,
                 recreate=False,
                 params=None,
                 *args,
                 **kwargs):
        """Build (or load from cache) the sharded Fashion-MNIST pool.

        Args:
            split_num: number of label-sorted shards the data is cut into.
            pick_num: number of shards assigned to each pool item.
            batch_size: batch size used when grouping each item's samples.
            input_require_shape: forwarded to the base class; also part of the
                cache file name.
            shuffle: unused; accepted for interface compatibility.
            pool_size: only consulted together with ``params``. In that case
                ``split_num = pool_size * params['N']``. The value is then
                recomputed as ``split_num // pick_num``.
            recreate: rebuild the pool even when a cache file exists.
            params: optional dict overriding ``pick_num`` (``'N'``) and
                ``batch_size`` (``'batch_size'``). ``'C'`` is also read when
                ``pool_size`` is None, giving ``split_num = int(C * N)``.
            *args, **kwargs: ignored; accepted for interface compatibility.

        Raises:
            RuntimeError: if ``split_num``/``pick_num`` are not positive, or
                ``split_num`` is not divisible by ``pick_num``.
        """
        if params is not None:
            pick_num = params['N']
            if pool_size is not None:
                split_num = pool_size * pick_num
            else:
                split_num = int(params['C'] * pick_num)
            batch_size = params['batch_size']
        # Fail fast, before any download, instead of raising ZeroDivisionError
        # deep inside the pool construction.
        if split_num <= 0 or pick_num <= 0:
            raise RuntimeError('split_num and pick_num must be positive integers.')
        if split_num % pick_num != 0:
            raise RuntimeError('split_num must be divisible by the number of pick_num.')
        pool_size = split_num // pick_num
        # `name` doubles as the cache file name: keep its exact format so that
        # previously written caches are still found.
        name = (f'Fashion_pool_{pool_size}split_{split_num}pick{pick_num}'
                f'_batchsize_{batch_size}_sort_split_input_require_shape_{input_require_shape!s}')
        nickname = f'fashion B{batch_size} S{split_num} P{pick_num} N{pool_size}'
        super().__init__(name, nickname, pool_size, batch_size, input_require_shape)

        file_path = utils.pool_folder_path + name + '.npy'

        if os.path.exists(file_path) and not recreate:
            self._read_fashion_pool_cache(file_path)
        else:
            self._build_fashion_pool(split_num, pick_num, pool_size)
            np.save(file_path, self)

    def _read_fashion_pool_cache(self, file_path):
        """Copy every instance attribute of a previously saved pool onto ``self``."""
        # NOTE: allow_pickle=True unpickles the file, which can execute code;
        # only load caches this program wrote itself.
        cached = np.load(file_path, allow_pickle=True).item()
        for attr, value in vars(cached).items():
            setattr(self, attr, value)
        print('Successfully Read the Data Pool.')

    @staticmethod
    def _load_fashion_tensors():
        """Return ``(trainset, testset, inputs, targets)`` for Fashion-MNIST.

        ``inputs``/``targets`` hold every train and test sample, concatenated
        and shuffled, as produced by one full-size DataLoader batch.
        """
        transform = transforms.Compose([transforms.ToTensor()])
        trainset = torchvision.datasets.FashionMNIST(root=utils.data_folder_path, train=True,
                                                     download=True, transform=transform)
        testset = torchvision.datasets.FashionMNIST(root=utils.data_folder_path, train=False,
                                                    download=True, transform=transform)
        totalset = ConcatDataset([trainset, testset])
        # batch_size == len(totalset): a single batch materialises the whole dataset.
        totalloader = torch.utils.data.DataLoader(totalset, batch_size=len(totalset),
                                                  shuffle=True, num_workers=1)
        total_input_data, total_target_data = next(iter(totalloader))
        return trainset, testset, total_input_data, total_target_data

    @staticmethod
    def _contiguous_splits(sample_num, split_num):
        """Cut ``range(sample_num)`` into ``split_num`` contiguous index lists.

        Every split holds ``sample_num // split_num`` indices; the trailing
        ``sample_num % split_num`` indices are not assigned to any split.
        """
        amount = sample_num // split_num
        # Slice ends are exclusive, so (k + 1) * amount <= sample_num is always
        # in range; the previous clamp to sample_num - 1 dropped the last sample.
        return [list(range(k * amount, (k + 1) * amount)) for k in range(split_num)]

    def _make_fashion_batches(self, input_data, target_data, indices):
        """Group ``indices`` into ``batch_size`` chunks of (float inputs, targets) pairs."""
        return [
            (input_data[batch].reshape([-1] + self.input_data_shape).float(), target_data[batch])
            for batch in DataLoader.separate_list(indices, self.batch_size)
        ]

    def _build_fashion_pool(self, split_num, pick_num, pool_size):
        """Load the data, set dataset metadata and fill ``self.data_pool``."""
        trainset, testset, total_input_data, total_target_data = self._load_fashion_tensors()

        # Same torch.Size a full-size train batch would have (len(trainset),
        # *per-sample dims), without an extra full pass over the train/test
        # splits just to read a shape.
        self.cal_data_shape(torch.Size([len(trainset), *total_input_data.shape[1:]]))
        self.target_class_num = 10

        self.total_training_number = len(trainset)
        self.total_test_number = len(testset)
        self.output_size = 10
        self.model4data = 'mlp'
        self.task_name = 'fashion_classification'

        data_pool = [{} for _ in range(self.pool_size)]
        self._fill_fashion_pool(data_pool, total_input_data, total_target_data,
                                split_num, pick_num, pool_size)
        self.data_pool = data_pool

    def _fill_fashion_pool(self, data_pool, input_data, target_data, split_num, pick_num, pool_size):
        """Assign ``pick_num`` random label-sorted shards to each pool item."""
        # Sort by label so that each contiguous shard covers a narrow label range.
        order = torch.argsort(target_data)
        input_data = input_data[order, :]
        target_data = target_data[order]

        remaining_splits = self._contiguous_splits(input_data.shape[0], split_num)
        total_number = self.total_training_number + self.total_test_number

        for pool_idx in range(pool_size):
            data_indices = []
            for _ in range(pick_num):
                # Draw one remaining shard uniformly, without replacement.
                data_indices += remaining_splits.pop(random.randrange(len(remaining_splits)))
            random.shuffle(data_indices)

            # Local train share mirrors the global train/(train + test) ratio.
            split_idx = len(data_indices) * self.total_training_number // total_number
            train_indices = data_indices[:split_idx]
            test_indices = data_indices[split_idx:]

            item = data_pool[pool_idx]
            item['local_training_data'] = self._make_fashion_batches(input_data, target_data, train_indices)
            item['local_test_data'] = self._make_fashion_batches(input_data, target_data, test_indices)
            item['local_training_number'] = len(train_indices)
            item['local_test_number'] = len(test_indices)
            item['data_name'] = str(pool_idx)

    def allocate(self, client_list):
        """Give the i-th client the i-th pool item (deterministic, in order).

        Raises:
            IndexError: if there are more clients than pool items; raised
                before any client is updated.
        """
        clients = list(client_list)
        if len(clients) > self.pool_size:
            raise IndexError(f'{len(clients)} clients but only {self.pool_size} '
                             f'data pool items are available.')
        for pool_idx, client in enumerate(clients):
            item = self.data_pool[pool_idx]
            client.update_data(pool_idx,
                               item['local_training_data'],
                               item['local_training_number'],
                               item['local_test_data'],
                               item['local_test_number'])
# class DataLoader_fashion(DataLoader):
#     def __init__(self,
#                  split_num=200,
#                  pick_num=2,
#                  batch_size=100,
#                  input_require_shape=None,
#                  shuffle=True,
#                  pool_size=None,
#                  recreate=False,
#                  params=None,
#                  *args,
#                  **kwargs):
#         if params is not None:
#             if pool_size is not None:
#                 split_num = pool_size * params['N']
#                 pick_num = params['N']
#             else:
#                 split_num = int(params['C'] * params['N'])
#                 pick_num = params['N']
#             batch_size = params['batch_size']
#         if split_num % pick_num != 0:
#             raise RuntimeError('split_num must be divisible by the number of pick_num.')
#         pool_size = split_num // pick_num
#         name = 'Fashion_pool_' + str(pool_size) + 'split_' + str(split_num) + 'pick' + str(
#             pick_num) + '_batchsize_' + str(batch_size) + '_sort_split_input_require_shape_' + str(input_require_shape)
#         nickname = 'fashion B' + str(batch_size) + ' S' + str(split_num) + ' P' + str(pick_num) + ' N' + str(pool_size)
#         super().__init__(name, nickname, pool_size, batch_size, input_require_shape)
#
#         file_path = utils.pool_folder_path + name + '.npy'
#
#         if os.path.exists(file_path) and (recreate == False):
#             data_loader = np.load(file_path, allow_pickle=True).item()
#             for attr in list(data_loader.__dict__.keys()):
#                 setattr(self, attr, data_loader.__dict__[attr])
#             print('Successfully Read the Data Pool.')
#         else:
#             transform = transforms.Compose(
#                 [transforms.ToTensor()])
#             trainset = torchvision.datasets.FashionMNIST(root=utils.data_folder_path, train=True,
#                                                          download=True, transform=transform)
#             trainloader = torch.utils.data.DataLoader(trainset, batch_size=trainset.data.shape[0],
#                                                       shuffle=True, num_workers=1)
#             testset = torchvision.datasets.FashionMNIST(root=utils.data_folder_path, train=False,
#                                                         download=True, transform=transform)
#             testloader = torch.utils.data.DataLoader(testset, batch_size=testset.data.shape[0],
#                                                      shuffle=False, num_workers=1)
#             # global_training_data = torch.utils.data.DataLoader(copy.deepcopy(trainset),
#             #                                                    batch_size=self.batch_size,
#             #                                                    shuffle=True, num_workers=1)
#             # global_test_data = torch.utils.data.DataLoader(copy.deepcopy(testset),
#             #                                                batch_size=self.batch_size,
#             #                                                shuffle=False, num_workers=1)
#             # # modify
#             # num_samples = len(trainset)
#             # noise_ratio = 0.5
#             # num_samples_to_modify = int(num_samples * noise_ratio)
#             # indices_to_modify = np.random.choice(num_samples, num_samples_to_modify, replace=False)
#             # num_classes = 10
#             # for idx in indices_to_modify:
#             #     new_label = np.random.randint(0, num_classes)  # Generate random incorrect label
#             #     while new_label == trainset.targets[idx]:  # Ensure the new label is different from the original one
#             #         new_label = np.random.randint(0, num_classes)
#             #     trainset.targets[idx] = new_label
#
#             for i, (input_data, targets) in enumerate(trainloader):
#                 train_input_data = input_data
#                 train_target_data = targets
#             for i, (input_data, targets) in enumerate(testloader):
#                 test_input_data = input_data
#                 test_target_data = targets
#
#             self.cal_data_shape(train_input_data.shape)
#
#             self.target_class_num = 10
#
#             self.global_training_data = []
#             self.global_test_data = []
#             # for (input_data, targets) in global_training_data:
#             #     self.global_training_data.append((input_data.reshape([-1] + self.input_data_shape), targets))
#             # for (input_data, targets) in global_test_data:
#             #     self.global_test_data.append((input_data.reshape([-1] + self.input_data_shape), targets))
#             self.total_training_number = len(trainset)
#             self.total_test_number = len(testset)
#             self.output_size = 10
#             self.model4data = 'mlp'
#             self.task_name = 'fashion_classification'
#
#             def create_data_pool(data_pool, input_data, target_data, key_name):
#                 order = torch.argsort(target_data)
#                 input_data = input_data[order, :]
#                 target_data = target_data[order]
#
#                 count = 0
#                 amount = input_data.shape[0] // split_num
#                 indices = list(range(input_data.shape[0]))
#                 split_data_indices_list = []
#                 for split_idx in range(split_num):
#                     start_idx = count
#                     end_idx = count + amount
#                     if end_idx > input_data.shape[0] - 1:
#                         end_idx = input_data.shape[0] - 1
#                     split_data_indices = indices[start_idx: end_idx]
#                     split_data_indices_list.append(split_data_indices)
#                     count += amount
#                 for pool_idx in range(pool_size):
#                     data_indices = []
#
#                     for i in range(pick_num):
#                         pick_data_indices = split_data_indices_list[random.randint(0, len(split_data_indices_list) - 1)]
#                         data_indices += pick_data_indices
#                         split_data_indices_list.remove(pick_data_indices)
#                     random.shuffle(data_indices)
#                     local_data_number = len(data_indices)
#
#                     batch_data_indices_list = DataLoader.separate_list(data_indices, self.batch_size)
#                     local_data = []
#                     for batch_data_indices in batch_data_indices_list:
#                         batch_input_data = input_data[batch_data_indices].reshape([-1] + self.input_data_shape).float()
#                         batch_target_data = target_data[batch_data_indices]
#                         local_data.append((batch_input_data, batch_target_data))
#
#                     data_pool[pool_idx][key_name + '_data'] = local_data
#                     data_pool[pool_idx][key_name + '_number'] = local_data_number
#                     data_pool[pool_idx]['data_name'] = str(pool_idx)
#
#             data_pool = [{} for _ in range(self.pool_size)]
#
#             create_data_pool(data_pool, train_input_data, train_target_data, 'local_training')
#
#             create_data_pool(data_pool, test_input_data, test_target_data, 'local_test')
#             self.data_pool = data_pool
#             np.save(file_path, self)
#
#     def allocate(self, client_list):
#
#         choose_data_pool_item_indices = np.random.choice(list(range(self.pool_size)), len(client_list), replace=False)
#         for idx, client in enumerate(client_list):
#             data_pool_item = self.data_pool[choose_data_pool_item_indices[idx]]
#             client.update_data(choose_data_pool_item_indices[idx],
#                                data_pool_item['local_training_data'],
#                                data_pool_item['local_training_number'],
#                                data_pool_item['local_test_data'],
#                                data_pool_item['local_test_number'])






# import os, json
# import gzip
# import numpy as np
#
# NAME=[]
# def load_mnist(path, kind='train'):
#
#
#     """Load MNIST data from `path`"""
#     labels_path = os.path.join(path,
#                                '%s-labels-idx1-ubyte.gz'
#                                % kind)
#     images_path = os.path.join(path,
#                                '%s-images-idx3-ubyte.gz'
#                                % kind)
#
#     with gzip.open(labels_path, 'rb') as lbpath:
#         labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
#                                offset=8)
#
#     with gzip.open(images_path, 'rb') as imgpath:
#         images = np.frombuffer(imgpath.read(), dtype=np.uint8,
#                                offset=16).reshape(len(labels), 784)
#
#     return images, labels
#
#
# def generate_dataset():
#
#   X_train, y_train = load_mnist('raw_data/fashion', kind='train')
#   X_test, y_test = load_mnist('raw_data/fashion', kind='t10k')
#
#
#   # some simple normalization
#   mu = np.mean(X_train.astype(np.float32), 0)
#   sigma = np.std(X_train.astype(np.float32), 0)
#
#   X_train = (X_train.astype(np.float32) - mu)/(sigma+0.001)
#   X_test = (X_test.astype(np.float32) - mu)/(sigma+0.001)
#
#   return X_train.tolist(), y_train.tolist(), X_test.tolist(), y_test.tolist()
#
#
# def main():
#     train_output = "./train/mytrain.json"
#     test_output = "./test/mytest.json"
#
#
#     X_train, y_train, X_test, y_test = generate_dataset()
#
#
#     # Create data structure
#     train_data = {'users': [], 'user_data':{}, 'num_samples':[]}
#     test_data = {'users': [], 'user_data':{}, 'num_samples':[]}
#
#
#     # label 0: T-shirt(top); 2: pullover; 6: Shirt
#     X_trains=[[] for i in range(10)]
#     y_trains = [[] for i in range(10)]
#     for idx, item in enumerate(X_train):
#         i=y_train[idx]
#         X_trains[i].append(X_train[idx])
#         y_trains[i].append(y_train[idx])
#
#     X_tests = [[] for i in range(10)]
#     y_tests = [[] for i in range(10)]
#     for idx, item in enumerate(X_test):
#         i=y_test[idx]
#         X_tests[i].append(X_test[idx])
#         y_tests[i].append(y_test[idx])
#     label_dict={0:'T-shirt', 2:'pullover', 6:'shirt'}
#     selected=[0,2,6]
#     cvt_labels= {}
#     for i in range(len(selected)):
#         cvt_labels[selected[i]]=i
#     for i in selected:
#         train_len=len(X_trains[i])
#         print("training set for {}: {}".format(i,train_len))
#         test_len = len(X_tests[i])
#         uname=label_dict[i]
#         train_data['users'].append(uname)
#         train_data['user_data'][uname] = {'x': X_trains[i], 'y': [cvt_labels[lb] for lb in y_trains[i]]}
#         train_data['num_samples'].append(train_len)
#         test_data['users'].append(uname)
#         test_data['user_data'][uname] = {'x': X_tests[i], 'y': [cvt_labels[lb] for lb in y_tests[i]]}
#         test_data['num_samples'].append(test_len)
#
#     with open(train_output,'w') as outfile:
#         json.dump(train_data, outfile)
#     with open(test_output, 'w') as outfile:
#         json.dump(test_data, outfile)
#
#
# if __name__ == "__main__":
#     main()